import numpy as np
import d4rl
import os
import gym
import random

# Environments, dataset levels and subsampling fractions to export.
ENVS = ('halfcheetah', 'hopper', 'walker2d')
LEVELS = ('medium', 'medium-replay', 'medium-expert')
PERCENTS = (0.1, 0.01, 0.5)
# Output directory (relative to the current working directory).
OUT_DIR = '../traj_index'
# Seed for trajectory subsampling so exported index files are reproducible.
SEED = 0


def split_trajectories(observations, next_observations, terminals, tol=1e-6):
    """Split transition indices 0..N-1 into contiguous trajectories.

    A trajectory ends after step ``i`` when either the next stored observation
    does not continue from ``next_observations[i]`` (Euclidean distance above
    ``tol``), or ``terminals[i] == 1.0``. The last step always closes the final
    trajectory, so every index appears exactly once and no trajectory is empty.

    Args:
        observations: per-step observations, indexable along axis 0 with N rows.
        next_observations: per-step successor observations, same length as
            ``observations``.
        terminals: per-step terminal flags (bool or float); its flattened length
            defines N.
        tol: distance above which consecutive steps are treated as a break.

    Returns:
        A list of 1-D integer numpy arrays of consecutive indices, in order.
        Returns an empty list when N == 0.
    """
    terminals = np.asarray(terminals).reshape(-1)
    n = len(terminals)
    if n == 0:
        return []
    # Flatten per-step observations so the norm is taken over all features.
    obs = np.asarray(observations).reshape(n, -1)
    next_obs = np.asarray(next_observations).reshape(n, -1)
    # Discontinuity between step i's successor and step i+1's observation.
    gap = np.linalg.norm(obs[1:] - next_obs[:-1], axis=1) > tol
    ends = gap | (terminals[:-1] == 1.0)
    # A trajectory ending after step i means the next one starts at i + 1.
    # Cut points are strictly increasing and lie in 1..n-1, so no segment is empty.
    cuts = np.flatnonzero(ends) + 1
    return np.split(np.arange(n), cuts)


def sample_trajectory_indices(trajs, percent, rng=None):
    """Sample a fraction of trajectories and return their concatenated indices.

    ``int(len(trajs) * percent)`` trajectories are drawn without replacement.
    They are re-ordered by their first index, so the result is in dataset order.

    Args:
        trajs: list of non-empty index sequences (e.g. from ``split_trajectories``).
        percent: fraction of trajectories to keep; must satisfy
            ``0 <= int(len(trajs) * percent) <= len(trajs)``, otherwise
            ``random.sample`` raises ValueError.
        rng: object exposing ``sample`` (e.g. ``random.Random``); defaults to the
            global ``random`` module.

    Returns:
        A 1-D integer numpy array of transition indices (possibly empty).
    """
    rng = random if rng is None else rng
    chosen = rng.sample(trajs, k=int(len(trajs) * percent))  # without replacement
    chosen.sort(key=lambda traj: traj[0])
    if not chosen:
        # np.concatenate rejects an empty list; keep an integer dtype for indices.
        return np.array([], dtype=np.int64)
    return np.concatenate([np.asarray(traj) for traj in chosen])


def main(seed=SEED):
    """Export subsampled trajectory index files for each D4RL dataset/percent.

    Each dataset is loaded once. Its trajectories are split once and then
    subsampled for every entry of ``PERCENTS``. The result is written to
    ``{OUT_DIR}/{dataset_name}_{percent}.npy``.
    """
    rng = random.Random(seed)
    os.makedirs(OUT_DIR, exist_ok=True)
    for e in ENVS:
        for level in LEVELS:
            dataset_name = f'{e}-{level}-v2'
            env = gym.make(dataset_name)
            dataset = d4rl.qlearning_dataset(env)
            trajs = split_trajectories(
                dataset['observations'],
                dataset['next_observations'],
                dataset['terminals'],
            )
            for percent in PERCENTS:
                idx = sample_trajectory_indices(trajs, percent, rng)
                path = f'{OUT_DIR}/{dataset_name}_{percent}.npy'
                np.save(path, idx)


if __name__ == '__main__':
    main()